ScatterNd
=================

Accumulates the values in ``updates`` into the positions of the output tensor ``output`` selected by the index tensor ``indices``:

.. math::

   \text{output}[\text{indices}_i] = \text{output}[\text{indices}_i] + \text{updates}_i

where :math:`i` runs over the ``num_slices`` rows of ``indices``. Each row addresses a single element of ``output``, or a slice of it when ``indices_depth`` is smaller than ``output_ndim``. ``updates_i`` has the shape of the addressed element or slice.

Inputs:
    - **output** - Start address of the output tensor. It holds the base values before the computation and the result afterwards.
    - **output_shape** - Shape array of the output tensor.
    - **output_ndim** - Number of dimensions of the output tensor.
    - **indices** - Address of the index data, typically of shape ``(num_slices, indices_depth)``.
    - **indices_shape** - Shape array of the index tensor.
    - **indices_ndim** - Number of dimensions of the index tensor.
    - **updates** - Address of the update data.
    - **core_mask** (int, optional) - Core mask. Present only in the shared-memory version.

Outputs:
    - **output** - Address of the result (updated in place).

Supported platforms:
    ``FT78NE`` ``MT7004``

.. note::

   - FT78NE supports int8, int16, int32, fp32, fp64, cplx64 and cplx128.
   - MT7004 supports fp16, fp32, int16, int32 and cplx64.
   - Because the operator performs random-access writes, the multi-core implementation usually operates directly on data in DDR.
   - Tensors of up to 8 dimensions are supported.

**Shared-memory version:**

.. c:function:: void i8_scatter_nd_s(int8_t* output, int* output_shape, int output_ndim, int* indices, int* indices_shape, int indices_ndim, int8_t* updates, int core_mask)

.. c:function:: void i16_scatter_nd_s(int16_t* output, int* output_shape, int output_ndim, int* indices, int* indices_shape, int indices_ndim, int16_t* updates, int core_mask)

.. c:function:: void i32_scatter_nd_s(int32_t* output, int* output_shape, int output_ndim, int* indices, int* indices_shape, int indices_ndim, int32_t* updates, int core_mask)

.. c:function:: void hp_scatter_nd_s(half* output, int* output_shape, int output_ndim, int* indices, int* indices_shape, int indices_ndim, half* updates, int core_mask)

.. c:function:: void fp_scatter_nd_s(float* output, int* output_shape, int output_ndim, int* indices, int* indices_shape, int indices_ndim, float* updates, int core_mask)

.. c:function:: void dp_scatter_nd_s(double* output, int* output_shape, int output_ndim, int* indices, int* indices_shape, int indices_ndim, double* updates, int core_mask)

.. c:function:: void c64_scatter_nd_s(float* output, int* output_shape, int output_ndim, int* indices, int* indices_shape, int indices_ndim, float* updates, int core_mask)

.. c:function:: void c128_scatter_nd_s(double* output, int* output_shape, int output_ndim, int* indices, int* indices_shape, int indices_ndim, double* updates, int core_mask)

**C call example:**

.. code-block:: c
   :linenos:
   :emphasize-lines: 13

   // FT78NE example (shared memory)
   #include
   #include "78NE/utils.h"
   int main() {
       float *output = (float *)0xA0000000;  // base output tensor in DDR, shape (4, 4, 4)
       float *updates = (float *)0xB0000000; // update values in DDR: 5 slices of 4 floats
       int *indices = (int *)0xC0000000;     // indices in DDR, shape (5, 2)
       int out_shape[] = {4, 4, 4};
       int ind_shape[] = {5, 2};
       int out_ndim = 3;
       int ind_ndim = 2;
       int core_mask = 0x0B;
       fp_scatter_nd_s(output, out_shape, out_ndim, indices, ind_shape, ind_ndim, updates, core_mask);
       return 0;
   }

**Private-memory version:**

.. c:function:: void i8_scatter_nd_p(int8_t* output, int* output_shape, int output_ndim, int* indices, int* indices_shape, int indices_ndim, int8_t* updates)

.. c:function:: void i16_scatter_nd_p(int16_t* output, int* output_shape, int output_ndim, int* indices, int* indices_shape, int indices_ndim, int16_t* updates)

.. c:function:: void i32_scatter_nd_p(int32_t* output, int* output_shape, int output_ndim, int* indices, int* indices_shape, int indices_ndim, int32_t* updates)

.. c:function:: void hp_scatter_nd_p(half* output, int* output_shape, int output_ndim, int* indices, int* indices_shape, int indices_ndim, half* updates)

.. c:function:: void fp_scatter_nd_p(float* output, int* output_shape, int output_ndim, int* indices, int* indices_shape, int indices_ndim, float* updates)

.. c:function:: void dp_scatter_nd_p(double* output, int* output_shape, int output_ndim, int* indices, int* indices_shape, int indices_ndim, double* updates)

.. c:function:: void c64_scatter_nd_p(float* output, int* output_shape, int output_ndim, int* indices, int* indices_shape, int indices_ndim, float* updates)

.. c:function:: void c128_scatter_nd_p(double* output, int* output_shape, int output_ndim, int* indices, int* indices_shape, int indices_ndim, double* updates)

**C call example:**

.. code-block:: c
   :linenos:
   :emphasize-lines: 14

   // MT7004 example
   #include

   int main() {
       float *output = (float *)0x10810000;  // output tensor, shape (4, 4, 4)
       float *updates = (float *)0x10820000; // update values: 5 slices of 4 floats
       int *indices = (int *)0x10830000;     // indices, shape (5, 2)

       int out_shape[] = {4, 4, 4};
       int ind_shape[] = {5, 2};
       int out_ndim = 3;
       int ind_ndim = 2;

       fp_scatter_nd_p(output, out_shape, out_ndim, indices, ind_shape, ind_ndim, updates);
       return 0;
   }
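
The accumulation rule at the top of this page can be sketched as a plain single-core C reference, for example to check the output of ``fp_scatter_nd_s`` / ``fp_scatter_nd_p`` on small inputs. This is an illustrative sketch, not the library's implementation: ``ref_scatter_nd_f32`` is a hypothetical name, and it assumes contiguous row-major tensors, int indices whose last dimension is ``indices_depth`` (all leading dimensions together give ``num_slices``), and an ``updates`` buffer holding ``num_slices`` slices laid out back to back. How the vendor kernels treat duplicate or out-of-range indices is not stated here; the sketch simply accumulates duplicates and asserts on out-of-range rows.

```c
#include <assert.h>
#include <stddef.h>

/* Hypothetical single-core reference (not part of the library):
 * output[indices[s]] += updates[s] for every index row s.
 * Assumes contiguous row-major tensors. */
void ref_scatter_nd_f32(float *output, const int *output_shape, int output_ndim,
                        const int *indices, const int *indices_shape, int indices_ndim,
                        const float *updates)
{
    /* The last index dimension is indices_depth. */
    int depth = indices_shape[indices_ndim - 1];
    assert(depth >= 1 && depth <= output_ndim);

    /* All leading index dimensions together give num_slices. */
    size_t num_slices = 1;
    for (int d = 0; d < indices_ndim - 1; ++d)
        num_slices *= (size_t)indices_shape[d];

    /* Each row selects a slice spanning the trailing output dimensions. */
    size_t slice_size = 1;
    for (int d = depth; d < output_ndim; ++d)
        slice_size *= (size_t)output_shape[d];

    for (size_t s = 0; s < num_slices; ++s) {
        const int *row = indices + s * (size_t)depth;

        /* Row-major offset of the slice start (Horner over the first `depth` dims). */
        size_t offset = 0;
        for (int d = 0; d < depth; ++d) {
            assert(row[d] >= 0 && row[d] < output_shape[d]);
            offset = offset * (size_t)output_shape[d] + (size_t)row[d];
        }
        offset *= slice_size;

        /* Accumulate rather than overwrite, so repeated rows add up. */
        for (size_t k = 0; k < slice_size; ++k)
            output[offset + k] += updates[s * slice_size + k];
    }
}
```

With the shapes used in the call examples (``out_shape = {4, 4, 4}``, ``ind_shape = {5, 2}``), each of the 5 index rows selects a 4-element slice ``output[a][b][:]``, so this sketch reads 20 floats from ``updates``.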